{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "104f2b97-f9bb-4dcc-a4c8-099710768851",
   "metadata": {},
   "source": [
    "# Parallel Tool Use"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8dc57b6-2c48-4ee3-bb2c-25441274ed2f",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e70814b4",
   "metadata": {},
   "source": [
    "Make sure `ipykernel` and `pip` are installed before you run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "962ae5e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e21816b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Groq API key configured: gsk_7FdrzM...'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "from groq import Groq\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "load_dotenv()\n",
    "\"Groq API key configured: \" + os.environ[\"GROQ_API_KEY\"][:10] + \"...\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f7c9c55-e925-4cc1-89f2-58237acf14a4",
   "metadata": {},
   "source": [
    "We will use the `llama3-70b-8192` model in this demo. You will need a Groq API key to proceed; you can create an account [here](https://console.groq.com/) and generate one for free. At the time of writing (05/07/2024), only Llama 3 models support parallel tool use.\n",
    "\n",
    "We recommend the 70B Llama 3 model, because the 8B model is less consistent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0cca781b-1950-4167-b36a-c1099d6b3b00",
   "metadata": {},
   "outputs": [],
   "source": [
    "client = Groq(api_key=os.getenv(\"GROQ_API_KEY\"))\n",
    "model = \"llama3-70b-8192\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c23ec2b",
   "metadata": {},
   "source": [
    "Let's define a dummy function that we can invoke in our tool use loop. It returns a hard-coded temperature for a few cities and a default value for any other city."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f2ce18dc",
   "metadata": {},
   "outputs": [],
   "source": [
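    "def get_weather(city: str) -> int:\n",
    "    \"\"\"Return a hard-coded temperature in degrees Celsius for the given city.\"\"\"\n",
    "    if city == \"Madrid\":\n",
    "        return 35\n",
    "    elif city == \"San Francisco\":\n",
    "        return 18\n",
    "    elif city == \"Paris\":\n",
    "        return 20\n",
    "    else:\n",
    "        # Default for any city not listed above\n",
    "        return 15"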
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a37e3c92",
   "metadata": {},
   "source": [
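    "Now we define our messages and tools and run the completion request. The user asks about three cities at once, so the model can answer with several tool calls in a single assistant message."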
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b454910-4352-40cc-b9b2-cc79edabd7c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"\"\"You are a helpful assistant.\"\"\"},\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"What is the weather in Paris, Tokyo and Madrid?\",\n",
    "    },\n",
    "]\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_weather\",\n",
    "            \"description\": \"Returns the weather in the given city in degrees Celsius\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The name of the city\",\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"city\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]\n",
    "response = client.chat.completions.create(\n",
    "    model=model, messages=messages, tools=tools, tool_choice=\"auto\", max_tokens=4096\n",
    ")\n",
    "\n",
    "response_message = response.choices[0].message"
   ]
  },
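  {
   "cell_type": "markdown",
   "id": "b7e1c0a4",
   "metadata": {},
   "source": [
    "With `tool_choice=\"auto\"`, the model decides whether to call a tool, so a response may contain no tool calls at all. The sketch below is a minimal guard for that case, assuming `tool_calls` is `None` or empty when the model answers directly. The helper name `needs_tool_round` is hypothetical, and the `SimpleNamespace` objects are stand-ins for real response messages so that the cell runs without an API call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d9f2e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "def needs_tool_round(message) -> bool:\n",
    "    \"\"\"Return True when the assistant message carries at least one tool call.\"\"\"\n",
    "    return bool(getattr(message, \"tool_calls\", None))\n",
    "\n",
    "\n",
    "# Stand-ins for response.choices[0].message (hypothetical values, not real API output).\n",
    "direct_answer = SimpleNamespace(content=\"Hello!\", tool_calls=None)\n",
    "tool_request = SimpleNamespace(content=None, tool_calls=[SimpleNamespace(id=\"call_0\")])\n",
    "\n",
    "for message in (direct_answer, tool_request):\n",
    "    print(needs_tool_round(message))"
   ]
  },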
  {
   "cell_type": "markdown",
   "id": "25c2838f",
   "metadata": {},
   "source": [
    "## Processing the tool calls\n",
    "\n",
    "Now we process the assistant message and construct the messages required to continue the conversation. This includes invoking each `tool_call` against our actual function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fe623ab9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\n",
      "  {\n",
      "    \"role\": \"system\",\n",
      "    \"content\": \"You are a helpful assistant.\"\n",
      "  },\n",
      "  {\n",
      "    \"role\": \"user\",\n",
      "    \"content\": \"What is the weather in Paris, Tokyo and Madrid?\"\n",
      "  },\n",
      "  {\n",
      "    \"role\": \"assistant\",\n",
      "    \"tool_calls\": [\n",
      "      {\n",
      "        \"id\": \"call_5ak8\",\n",
      "        \"function\": {\n",
      "          \"name\": \"get_weather\",\n",
      "          \"arguments\": \"{\\\"city\\\":\\\"Paris\\\"}\"\n",
      "        },\n",
      "        \"type\": \"function\"\n",
      "      },\n",
      "      {\n",
      "        \"id\": \"call_zq26\",\n",
      "        \"function\": {\n",
      "          \"name\": \"get_weather\",\n",
      "          \"arguments\": \"{\\\"city\\\":\\\"Tokyo\\\"}\"\n",
      "        },\n",
      "        \"type\": \"function\"\n",
      "      },\n",
      "      {\n",
      "        \"id\": \"call_znf3\",\n",
      "        \"function\": {\n",
      "          \"name\": \"get_weather\",\n",
      "          \"arguments\": \"{\\\"city\\\":\\\"Madrid\\\"}\"\n",
      "        },\n",
      "        \"type\": \"function\"\n",
      "      }\n",
      "    ]\n",
      "  },\n",
      "  {\n",
      "    \"role\": \"tool\",\n",
      "    \"content\": \"20\",\n",
      "    \"tool_call_id\": \"call_5ak8\"\n",
      "  },\n",
      "  {\n",
      "    \"role\": \"tool\",\n",
      "    \"content\": \"15\",\n",
      "    \"tool_call_id\": \"call_zq26\"\n",
      "  },\n",
      "  {\n",
      "    \"role\": \"tool\",\n",
      "    \"content\": \"35\",\n",
      "    \"tool_call_id\": \"call_znf3\"\n",
      "  }\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "tool_calls = response_message.tool_calls\n",
    "\n",
    "messages.append(\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"tool_calls\": [\n",
    "            {\n",
    "                \"id\": tool_call.id,\n",
    "                \"function\": {\n",
    "                    \"name\": tool_call.function.name,\n",
    "                    \"arguments\": tool_call.function.arguments,\n",
    "                },\n",
    "                \"type\": tool_call.type,\n",
    "            }\n",
    "            for tool_call in tool_calls\n",
    "        ],\n",
    "    }\n",
    ")\n",
    "\n",
    "available_functions = {\n",
    "    \"get_weather\": get_weather,\n",
    "}\n",
    "for tool_call in tool_calls:\n",
    "    function_name = tool_call.function.name\n",
    "    function_to_call = available_functions[function_name]\n",
    "    function_args = json.loads(tool_call.function.arguments)\n",
    "    function_response = function_to_call(**function_args)\n",
    "\n",
    "    # Note how we create a separate tool message for each tool call.\n",
    "    # The model matches each result to its call through the tool_call_id.\n",
    "    messages.append(\n",
    "        {\n",
    "            \"role\": \"tool\",\n",
    "            \"content\": json.dumps(function_response),\n",
    "            \"tool_call_id\": tool_call.id,\n",
    "        }\n",
    "    )\n",
    "\n",
    "print(json.dumps(messages, indent=2))"
   ]
  },
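  {
   "cell_type": "markdown",
   "id": "d4a8e5b1",
   "metadata": {},
   "source": [
    "The loop above assumes that every tool call names a function in `available_functions` and that its arguments are valid JSON matching the function signature. A model can get either wrong, so one option is to return an error payload as the tool result instead of raising an exception. The following is a minimal sketch of that idea and is not part of the Groq API: `run_tool_call`, `demo_weather`, and `demo_registry` are hypothetical names, and the error format is an assumption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f2b6c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "def demo_weather(city: str) -> int:\n",
    "    # Stand-in for the dummy get_weather function defined earlier in this notebook.\n",
    "    return {\"Madrid\": 35, \"San Francisco\": 18, \"Paris\": 20}.get(city, 15)\n",
    "\n",
    "\n",
    "def run_tool_call(name: str, arguments: str, registry: dict) -> str:\n",
    "    \"\"\"Run one tool call and always return a JSON string for the tool message.\"\"\"\n",
    "    if name not in registry:\n",
    "        return json.dumps({\"error\": f\"unknown tool: {name}\"})\n",
    "    try:\n",
    "        kwargs = json.loads(arguments)\n",
    "        return json.dumps(registry[name](**kwargs))\n",
    "    except (json.JSONDecodeError, TypeError) as exc:\n",
    "        # Malformed JSON, or arguments that do not match the function signature.\n",
    "        return json.dumps({\"error\": str(exc)})\n",
    "\n",
    "\n",
    "demo_registry = {\"get_weather\": demo_weather}\n",
    "print(run_tool_call(\"get_weather\", '{\"city\": \"Paris\"}', demo_registry))\n",
    "print(run_tool_call(\"get_time\", '{\"city\": \"Paris\"}', demo_registry))\n",
    "print(run_tool_call(\"get_weather\", '{\"city\": ', demo_registry))"
   ]
  },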
  {
   "cell_type": "markdown",
   "id": "1abe981a",
   "metadata": {},
   "source": [
    "Now we run the final completion with the tool call results included in the messages array.\n",
    "\n",
    "**Note**\n",
    "\n",
    "We pass the tool definitions again to help the model:\n",
    "\n",
    "1. Understand the assistant message that contains the tool calls.\n",
    "2. Interpret the tool results."
   ]
  },
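  {
   "cell_type": "markdown",
   "id": "f1b3d7a9",
   "metadata": {},
   "source": [
    "Each tool message is matched to its tool call through `tool_call_id`, so before sending the final request it can help to confirm that every requested call has a result. The sketch below is one way to check this on a plain list of message dictionaries. The function name `unanswered_tool_calls` and the example ids are hypothetical, and the tool call entries are trimmed to their ids for brevity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6c2e8d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def unanswered_tool_calls(messages: list) -> set:\n",
    "    \"\"\"Return the ids of tool calls that have no matching tool message.\"\"\"\n",
    "    requested = {\n",
    "        call[\"id\"]\n",
    "        for message in messages\n",
    "        if message.get(\"role\") == \"assistant\"\n",
    "        for call in message.get(\"tool_calls\") or []\n",
    "    }\n",
    "    answered = {\n",
    "        message[\"tool_call_id\"] for message in messages if message.get(\"role\") == \"tool\"\n",
    "    }\n",
    "    return requested - answered\n",
    "\n",
    "\n",
    "# Hand-written example history with hypothetical ids; call_b has no result yet.\n",
    "example_messages = [\n",
    "    {\"role\": \"user\", \"content\": \"What is the weather in Paris and Madrid?\"},\n",
    "    {\"role\": \"assistant\", \"tool_calls\": [{\"id\": \"call_a\"}, {\"id\": \"call_b\"}]},\n",
    "    {\"role\": \"tool\", \"content\": \"20\", \"tool_call_id\": \"call_a\"},\n",
    "]\n",
    "print(unanswered_tool_calls(example_messages))"
   ]
  },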
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5f077df3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The weather in Paris is 20°C, in Tokyo is 15°C, and in Madrid is 35°C.\n"
     ]
    }
   ],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=model, messages=messages, tools=tools, tool_choice=\"auto\", max_tokens=4096\n",
    ")\n",
    "\n",
    "print(response.choices[0].message.content)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
